"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch
import torch.nn.functional as F
from copy import deepcopy
from torch import nn

from ...common.registry import registry
from ..base_model import MomentumDistilationMixin
from ..med import XBertEncoder
from ..vit import VisionTransformerEncoder
from .blip import BlipBase
from .blip_outputs import BlipIntermediateOutput, BlipOutputWithLogits


@registry.register_model('blip_classification')
class BlipClassification(BlipBase, MomentumDistilationMixin):
    """BLIP model for image-text classification.

    Fuses image features (from a ViT encoder) with tokenized text (via a
    cross-attending BERT encoder) and classifies the multimodal [CLS]
    representation with a small MLP head. Optionally uses momentum (EMA)
    copies of the encoders and head as a distillation teacher during
    training.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        'base': 'configs/models/blip_classification_base.yaml',
    }

    def __init__(
        self,
        image_encoder,
        text_encoder,
        num_classes,
        momentum=0.995,
        alpha=0.4,
        max_txt_len=40,
        use_distill=True,
    ):
        """
        Args:
            image_encoder: visual backbone exposing ``forward_features``.
            text_encoder: text/multimodal encoder exposing
                ``forward_automask`` and ``config.hidden_size``.
            num_classes (int): number of output classes for the head.
            momentum (float): EMA decay for the momentum (teacher) copies.
            alpha (float): maximum weight of the distillation loss term.
            max_txt_len (int): tokenizer truncation length.
            use_distill (bool): whether to build momentum copies and train
                with soft-label distillation.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.use_distill = use_distill

        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        # Two-layer MLP classification head applied to the multimodal
        # [CLS] token of the text encoder output.
        hidden_size = text_encoder.config.hidden_size
        self.cls_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

        if self.use_distill:
            # Momentum (EMA) copies act as the distillation teacher; they
            # are updated via _momentum_update(), not by the optimizer.
            self.visual_encoder_m = deepcopy(self.visual_encoder)
            self.text_encoder_m = deepcopy(self.text_encoder)
            self.cls_head_m = deepcopy(self.cls_head)

            self.momentum = momentum
            self.alpha = alpha

            # Pairs consumed by MomentumDistilationMixin.copy_params() /
            # _momentum_update().
            self.model_pairs = [
                [self.visual_encoder, self.visual_encoder_m],
                [self.text_encoder, self.text_encoder_m],
                [self.cls_head, self.cls_head_m],
            ]

            self.copy_params()

        self.max_txt_len = max_txt_len

    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):
        """Linearly ramp the distillation weight from 0 to 1 over the
        first epoch, then hold it at 1."""
        return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch)

    def forward(self, samples, is_train=True):
        """Run classification; compute the loss when ``is_train``.

        Args:
            samples (dict): requires ``image``, ``text_input``, ``label``;
                when training with distillation, also ``epoch``, ``iters``
                and ``num_iters_per_epoch`` for the loss ramp-up.
            is_train (bool): if True return a BlipOutputWithLogits with the
                loss; otherwise return predictions and targets.
        """
        sentences = samples['text_input']
        sentences = self.tokenizer(
            sentences,
            padding='longest',
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors='pt',
        ).to(self.device)
        samples.update({'tokenized_text': sentences})

        targets = samples['label']

        image_embeds = self.visual_encoder.forward_features(samples['image'])
        encoder_output = self.text_encoder.forward_automask(samples['tokenized_text'], image_embeds)

        # Classify on the multimodal [CLS] token (position 0).
        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])

        if is_train:
            if self.use_distill:
                with torch.no_grad():
                    self._momentum_update()

                    # NOTE(review): the teacher calls the encoder directly
                    # while the student uses forward_features — presumably
                    # equivalent for this encoder; confirm.
                    image_embeds_m = self.visual_encoder_m(samples['image'])
                    encoder_output_m = self.text_encoder_m.forward_automask(samples['tokenized_text'], image_embeds_m)

                    prediction_m = self.cls_head_m(encoder_output_m.last_hidden_state[:, 0, :])

                alpha = self.alpha * self._rampup_factor(
                    epoch=samples['epoch'],
                    iters=samples['iters'],
                    num_iters_per_epoch=samples['num_iters_per_epoch'],
                )

                # Blend the hard-label CE loss with a soft cross-entropy
                # against the momentum teacher's distribution.
                loss = (1 - alpha) * F.cross_entropy(prediction, targets) - alpha * torch.sum(
                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),
                    dim=1,
                ).mean()
            else:
                loss = F.cross_entropy(prediction, targets)
                # No momentum teacher in this path; fill the corresponding
                # output fields with None (previously these names were
                # undefined here, raising NameError below).
                image_embeds_m = encoder_output_m = prediction_m = None

            return BlipOutputWithLogits(
                loss=loss,
                intermediate_output=BlipIntermediateOutput(
                    image_embeds=image_embeds,
                    image_embeds_m=image_embeds_m,
                    encoder_output=encoder_output,
                    encoder_output_m=encoder_output_m,
                ),
                logits=prediction,
                logits_m=prediction_m,
            )

        else:
            return {'predictions': prediction, 'targets': targets}

    def predict(self, samples):
        """Inference convenience wrapper: forward pass without loss."""
        output = self.forward(samples, is_train=False)
        return output

    @classmethod
    def from_config(cls, cfg=None):
        """Build the model from a config object; optionally load
        pretrained weights from ``cfg['pretrained']``."""
        image_encoder = VisionTransformerEncoder.from_config(cfg)

        # text encoder + multimodal encoder
        text_encoder = XBertEncoder.from_config(cfg)
        use_distill = cfg.get('use_distill', True)
        momentum = cfg.get('momentum', 0.995)
        num_classes = cfg.get('num_classes', -1)
        alpha = cfg.get('alpha', 0.4)
        max_txt_len = cfg.get('max_txt_len', 40)

        # Explicit raise instead of assert so validation survives `python -O`.
        if num_classes <= 1:
            raise ValueError('Invalid number of classes provided, found {}'.format(num_classes))

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            use_distill=use_distill,
            alpha=alpha,
            num_classes=num_classes,
            momentum=momentum,
            max_txt_len=max_txt_len,
        )

        # load pre-trained weights
        pretrain_path = cfg.get('pretrained', None)
        if pretrain_path is not None:
            model.load_from_pretrained(url_or_filename=pretrain_path)

        return model
